{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Deep Knowledge Tracing\n",
    "\n",
    "This notebook shows how to train and use a Deep Knowledge Tracing (DKT) model.\n",
    "First, we show how to get the data (here we use a0910 as the dataset).\n",
    "Then we show how to train a DKT model and persist its parameters.\n",
    "Finally, we show how to load the parameters from file and evaluate the model on the test dataset.\n",
    "\n",
    "A script version can be found in [DKT.py](DKT.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Data Preparation\n",
    "\n",
    "Before processing the data, we first need to acquire the dataset, as shown in [prepare_dataset.ipynb](prepare_dataset.ipynb)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "reading data from ../../data/a0910c/train.json: 3966it [00:00, 23302.50it/s]\n",
      "batchify: 100%|██████████| 130/130 [00:00<00:00, 629.18it/s]\n",
      "reading data from ../../data/a0910c/valid.json: 472it [00:00, 27564.52it/s]\n",
      "e:\\program\\baize\\baize\\extlib\\bucketing.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[55, 58, 59, 61, 65, 69, 74, 76, 77, 79, 80, 88, 90, 94, 95, 96, 99]\n",
      "  warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n",
      "batchify: 100%|██████████| 84/84 [00:00<00:00, 1241.57it/s]\n",
      "reading data from ../../data/a0910c/test.json: 1088it [00:00, 13857.38it/s]\n",
      "e:\\program\\baize\\baize\\extlib\\bucketing.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[73, 88]\n",
      "  warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n",
      "batchify: 100%|██████████| 101/101 [00:00<00:00, 931.51it/s]\n"
     ]
    }
   ],
   "source": [
    "from XKT.DKT import etl\n",
    "batch_size = 32\n",
    "train = etl(\"../../data/a0910c/train.json\", batch_size=batch_size)\n",
    "valid = etl(\"../../data/a0910c/valid.json\", batch_size=batch_size)\n",
    "test = etl(\"../../data/a0910c/test.json\", batch_size=batch_size)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training and Persistence"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logger: <Logger model (INFO)>\n",
      "model_name: model\n",
      "model_dir: model\n",
      "begin_epoch: 0\n",
      "end_epoch: 2\n",
      "batch_size: 32\n",
      "save_epoch: 1\n",
      "optimizer: Adam\n",
      "optimizer_params: {'learning_rate': 0.001, 'wd': 0.0001, 'clip_gradient': 1}\n",
      "lr_params: {}\n",
      "train_select: None\n",
      "save_select: None\n",
      "ctx: cpu(0)\n",
      "train_ctx: None\n",
      "eval_ctx: None\n",
      "toolbox_params: {}\n",
      "hyper_params: {'ku_num': 146, 'hidden_num': 100}\n",
      "init_params: {}\n",
      "loss_params: {}\n",
      "caption: \n",
      "validation_result_file: model\\result.json\n",
      "cfg_path: model\\configuration.json\n",
      "Epoch| Total-E          Batch     Total-B       Loss-SequenceLogisticMaskLoss             Progress           \n",
      "    0|       1            130         130                            0.639498     [00:04<00:00, 30.63it/s]   \n",
      "Epoch [0]\tLoss - SequenceLogisticMaskLoss: 0.639498\n",
      "           precision    recall        f1  support\n",
      "0.0         0.427845  0.205795  0.277913     7765\n",
      "1.0         0.689022  0.864755  0.766951    15801\n",
      "macro_avg   0.558433  0.535275  0.522432    23566\n",
      "accuracy: 0.647628\tmacro_auc: 0.562898\tmacro_aupoc: 0.713164\n",
      "Epoch| Total-E          Batch     Total-B       Loss-SequenceLogisticMaskLoss             Progress           \n",
      "    1|       1            130         130                            0.624943     [00:04<00:00, 31.97it/s]   \n",
      "Epoch [1]\tLoss - SequenceLogisticMaskLoss: 0.624943\n",
      "           precision    recall        f1  support\n",
      "0.0         0.472871  0.153767  0.232070     7765\n",
      "1.0         0.687705  0.915765  0.785517    15801\n",
      "macro_avg   0.580288  0.534766  0.508793    23566\n",
      "accuracy: 0.664686\tmacro_auc: 0.579386\tmacro_aupoc: 0.722303\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 130.42it/s]\n",
      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 130.78it/s]\n",
      "model, INFO writing configuration parameters to G:\\program\\XKT\\examples\\DKT\\dkt\\configuration.json\n"
     ]
    },
    {
     "data": {
      "text/plain": "'dkt'"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from XKT import DKT\n",
    "model = DKT(hyper_params=dict(ku_num=146, hidden_num=100))\n",
    "model.train(train, valid, end_epoch=2)\n",
    "model.save(\"dkt\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading and Testing"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 113.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           precision    recall        f1  support\n",
      "0.0         0.484619  0.157390  0.237611    17517\n",
      "1.0         0.670330  0.911000  0.772351    32944\n",
      "macro_avg   0.577475  0.534195  0.504981    50461\n",
      "accuracy: 0.649393\tmacro_auc: 0.570926\tmacro_aupoc: 0.702939\n"
     ]
    }
   ],
   "source": [
    "model = DKT.from_pretrained(\"dkt\")\n",
    "print(model.eval(test))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Predict"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "inputs = mx.nd.ones((2, 3))  # (2 students, 3 steps)\n",
    "outputs, _ = model(inputs)\n",
    "outputs.shape"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}